import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 
import glob
import numpy as np
from matplotlib.ticker import ScalarFormatter
import matplotlib.ticker as ticker
import string
class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter variant whose tick-label format is pinned to two decimals."""

    # printf-style pattern assigned to ``self.format`` for every axis.
    _FIXED_FORMAT = "%0.2f"

    def _set_format(self):
        # Replace whatever format the base class would pick with the fixed one.
        self.format = self._FIXED_FORMAT
        
        
# =============================================================================
# Processing the DSSAT output to generate a dataframe for plots
# =============================================================================
# Folder holding the DSSAT output tables; every *.txt file found in it is read.
path=r"\DSSAT outputs"
all_files = glob.glob(path + "/*.txt")
# Column names applied to the header-less, whitespace-delimited output files.
my_cols=['TRT','PDAT','ADAT','MDAT','Yield','SWC','DRCM','IR','Ng','RU','TMAX','TMIN','PRCP','ET','Tr']

if not all_files:
    # Fail early with a readable message; pd.concat([]) would otherwise raise
    # "No objects to concatenate" without saying which folder was empty.
    raise FileNotFoundError("No DSSAT output files (*.txt) found in {!r}".format(path))

# One dataframe per file; -99 is read as a missing value.
Da = [
    pd.read_csv(txt_path, delimiter=r'\s+', names=my_cols, header=None, na_values=[-99])
    for txt_path in all_files
]

DSSAT = pd.concat(Da, axis=0, ignore_index=True)
DSSAT.dropna(inplace=True)

# The last character of TRT is a code: codes 1 and 2 add 0 to the planting
# year, every other code adds 100.
DSSAT['YYcode']=DSSAT['TRT'].str[-1:].astype(int)
DSSAT['YYcode_interpret']=np.where(DSSAT['YYcode'].isin([1, 2]), 0, 100)
DSSAT['YYYY']=DSSAT['PDAT'].astype(str).str[:4].astype(int)
DSSAT['year']=DSSAT['YYYY']+DSSAT['YYcode_interpret']

# Fixed-position fields sliced out of the TRT identifier.
# NOTE(review): the slices imply a 13-character TRT
# (4 region + 4 GIS + 3 RCP + 1 GCM + 1 year code) — confirm against the data.
DSSAT['GIS_code']=DSSAT['TRT'].str[4:-5]
DSSAT['CNCT']=DSSAT['TRT'].str[:4]
DSSAT['GCM']=DSSAT['TRT'].str[11:-1]
DSSAT['RCP']=DSSAT['TRT'].str[8:-2]
DSSAT['Country']=DSSAT['TRT'].str[:2]

# Translate the 4-character region code into a readable region name.
DSSAT['City']=DSSAT['CNCT'].str.strip().replace(
    {
        "USCa": "California",
        "ItEm": "Emilia",
        "ItFo": "Foggia",
        "ChIn": "Inner Mongolia",
        "ChGa": "Gansu",
        "ChXi": "Xinjiang",
    },
    regex=True,
)

# Attach the GIS table to every simulation row (matched on TRT).
gis_data=pd.read_csv(r'\GIS-Jan2021.csv')
DSSAT_lat_lon=DSSAT.merge(gis_data, how='left', on='TRT')

# Latitude that separates "Northern" from "Southern" California.
# NOTE(review): the original comment said northern California starts at 35.8,
# but the threshold used is 37 — confirm which is intended.
NC_SC=37
west_of_114 = DSSAT_lat_lon['LON'] < -114
# Use .loc with a single indexer: the original chained form
# (df['City'][mask] = ...) may write to a temporary copy and is silently a
# no-op under pandas copy-on-write.
DSSAT_lat_lon.loc[west_of_114 & (DSSAT_lat_lon['LAT'] > NC_SC), 'City'] = "Northern California"
DSSAT_lat_lon.loc[west_of_114 & (DSSAT_lat_lon['LAT'] <= NC_SC), 'City'] = "Southern California"

# Columns and regions used by the Figure 1 code below.
col_list = ['year','City','GCM','RCP','Yield']
allcities=['California', 'Emilia','Foggia','Xinjiang','Gansu','Inner Mongolia']

# =============================================================================
# =============================================================================
# # Figure 1  --- the next lines are used to generate Figure 1 of the manuscript
# =============================================================================
# =============================================================================

F=38
B=800
plt.rcParams.update({
    'xtick.labelsize': 36,
    'ytick.labelsize': F,
    'axes.labelsize': F,
    'font.weight': B,
    'axes.labelweight': B,
    # The original called plt.rc_context({...}) as a bare statement; that only
    # builds a context manager and changes nothing unless used in a `with`
    # block, so the colours are applied to rcParams directly instead.
    'axes.edgecolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
})
plt.rc('grid', linestyle="--", color='grey')

fig1, axes = plt.subplots(3, 2,figsize=[26.5,24.5],sharex=False,sharey=False,dpi=300)
fig1.subplots_adjust(hspace=0.2, wspace=0.10)
axs = axes.ravel()

# Panel label and the area factor used to scale mean yield into production,
# per region.
CITY_INFO = {
    'California': ('United States (CA)', 112772),
    'Emilia': ('Italy (ER)', 24733),
    'Foggia': ('Italy (FG)', 24995),
    'Gansu': ('China (GA)', 1467),
    'Inner Mongolia': ('China (IM)', 8967),
    'Xinjiang': ('China (XJ)', 26476),
}
# (RCP code, legend label, line colour) in the order the lines are drawn.
SCENARIO_STYLES = [
    ('R70', "SSP3-7.0", '#EA5F94'),
    ('R85', "SSP5-8.5", '#0b9681'),
    ('R26', "SSP1-2.6", '#FFB14E'),
]

agg=[]  # per-region production tables, summed later for the global trend
for n, City in enumerate(allcities):
    lbl, Area = CITY_INFO[City]
    Cityplot=DSSAT.loc[DSSAT['City']==City, col_list]
    # Average only the numeric Yield column; a blanket agg('mean') also hits
    # the string 'City' column, which raises on pandas >= 2.
    grouped_data=Cityplot.groupby(['RCP', 'year','GCM'], as_index=False)['Yield'].mean()
    grouped_data['Production']=grouped_data['Yield']*Area/100000000
    # Keyword form: the positional axis argument was removed in pandas 2.
    grouped_data = grouped_data.drop(columns='Yield')
    agg.append(grouped_data)

    for rcp_code, ssp_label, line_colour in SCENARIO_STYLES:
        # .copy() so that adding 'mv_avg' never writes into a view of grouped_data.
        scenario = grouped_data.loc[grouped_data['RCP']==rcp_code].copy()
        # 10-year moving average within each GCM (min_periods=1 keeps the first years).
        scenario['mv_avg']=scenario.groupby('GCM')['Production'].transform(
            lambda x: x.rolling(10, 1).mean())
        axs[n] = sns.lineplot(x="year",y="mv_avg",linewidth=3, color=line_colour,
                              label=ssp_label, data=scenario, ax=axs[n])

    sns.despine(offset=15 )
    axs[n].legend("")
    axs[n].set_xlabel("")
    # plt.ylabel("Crop production (t) × 100000")
    axs[n].set_ylabel("")
    axs[n].annotate("("+string.ascii_uppercase[n]+") "+lbl, xy=(0.02, 0.03), xycoords='axes fraction',fontsize=38)
    yfmt = ScalarFormatterForceFormat()
    axs[n].yaxis.set_major_formatter(yfmt)
    # Style is passed explicitly so every panel gets the same dashed grey grid.
    axs[n].grid(True, linestyle="--", color='grey')
    plt.setp(axs[n].spines.values(), linewidth=2)
    plt.tight_layout()
# =============================================================================
# generate the global production timeseries -- Figure 1        
# =============================================================================
# Sum the per-region production tables into one series per scenario/year/GCM.
Agg_data = pd.concat(agg, axis=0, ignore_index=True)
grouped_aggdata=Agg_data.groupby(['RCP', 'year','GCM'], as_index=False)['Production'].sum()


def _with_moving_average(table, rcp_code):
    """Return the rows of *table* for *rcp_code* plus a 10-year 'mv_avg' per GCM."""
    # .copy() avoids assigning a new column on a slice of grouped_aggdata.
    subset = table.loc[table['RCP']==rcp_code].copy()
    subset['mv_avg']=subset.groupby('GCM')['Production'].transform(
        lambda x: x.rolling(10, 1).mean())
    return subset


RCP26=_with_moving_average(grouped_aggdata, 'R26')
RCP70=_with_moving_average(grouped_aggdata, 'R70')
RCP85=_with_moving_average(grouped_aggdata, 'R85')

# A bare plt.rc_context({...}) call has no effect outside a `with` block, so
# the intended colours are written to rcParams before the figure is created.
plt.rcParams.update({'axes.edgecolor':'black', 'xtick.color':'black', 'ytick.color':'black'})
plt.rc('grid', linestyle="--", color='grey')
fig, ax = plt.subplots(1, 1,figsize=[13,10])
ax = sns.lineplot(x="year",y="mv_avg",linewidth=3, label="SSP1-2.6",color='#FFB14E', data=RCP26, ax=ax)
ax = sns.lineplot(x="year",y="mv_avg",linewidth=3,  color='#EA5F94',label="SSP3-7.0" ,data=RCP70, ax=ax)
ax = sns.lineplot(x="year",y="mv_avg",linewidth=3, color='#0b9681', label="SSP5-8.5",data=RCP85, ax=ax)
sns.despine(offset=15 )
ax.legend(bbox_to_anchor=(0.38, -0.05),frameon=False,
          loc="lower right",fontsize=30)
ax.set_xlabel("")
ax.set_ylabel("Crop production (t) × 100,000")
ax.annotate("Global trend ", xy=(0.02, 0.95), xycoords='axes fraction',fontsize=40)
yfmt = ScalarFormatterForceFormat()
ax.yaxis.set_major_formatter(yfmt)
plt.setp(ax.spines.values(), linewidth=2)
ax.grid(True, linestyle="--", color='grey')
plt.tight_layout()

# =============================================================================
# =============================================================================
# Figure 2 --- correlation between yield , temp, rainfall [YTR]
# =============================================================================
# =============================================================================
from scipy.stats import gaussian_kde
def myfmt(x, pos):
    """Return *x* rendered with two decimals; *pos* is accepted but not used."""
    return format(x, '.2f')
# Columns needed for the yield / temperature / rainfall (YTR) analysis.
col_list_YTR= ['TRT','year','GIS_code','Country','CNCT','GCM','RCP','Yield','PRCP','TMAX','TMIN']
# Raw string: '\G' in a normal literal is an invalid escape sequence
# (same runtime value, but it triggers a warning on recent Python versions).
gis_data=pd.read_csv(r'\GIS-Jan2021.csv')
YTR= DSSAT[col_list_YTR]
YTR=YTR.merge(gis_data, how='left', on='TRT')
YTR['Tmean']=(YTR['TMAX']+YTR['TMIN'])/2
YTR['Yield']=YTR['Yield']/1000
# Region name from the two-letter code that follows the country code in CNCT.
YTR['City']=YTR['CNCT'].str[2:].str.strip().replace(
    {
        "Ca": "California",
        "Em": "Emilia",
        "Fo": "Foggia",
        "In": "Inner Mongolia",
        "Ga": "Gansu",
        "Xi": "Xinjiang",
    },
    regex=True,
)
# 30-year period label. Built in one vectorised step instead of chained
# assignments (df['Period'][mask] = ...), which do nothing under pandas
# copy-on-write. The first matching condition wins.
YTR['Period']=np.select(
    [YTR["year"] < 2010, YTR["year"] < 2040, YTR["year"] < 2070],
    ["1980-2009", "2010-2039", "2040-2069"],
    default="2070-2099",
)
allCities=['California', 'Foggia', 'Emilia', 'Xinjiang', 'Gansu','Inner Mongolia']
allCities_lbls=['United States\n(CA)', 'Italy\n(FG)','Italy\n(ER)', 'China\n(XJ)',  'China\n(GA)','China\n(IM)']

# 3 scenario rows x 6 region columns of yield-vs-temperature density scatters.
fig1, axes = plt.subplots(3, 6,figsize=[45,20],sharex=True,sharey=True,dpi=300)
cols = allCities_lbls
rows = ["SSP1-2.6","SSP3-7.0","SSP5-8.5"]
pad = 20 # in points
# Column headers above the top row.
for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size=45,weight='bold', ha='center', va='baseline')

# Row labels, anchored on the first column but offset 6.73 axes-widths to the right.
for ax, row in zip(axes[:,0], rows):
    ax.annotate(row, xy=(6.73, 0.5), xytext=(-ax.yaxis.labelpad + pad, 0),
                xycoords='axes fraction', textcoords='offset points',
                size=58, weight='bold',ha='right', va='center',rotation=90)

bottom, top = 0.1, 0.9
left, right = 0.1, 0.8
fig1.subplots_adjust(top=top, bottom=bottom, left=left, right=right, hspace=0.05, wspace=0.10)
# Axes are shared, so setting the limits on one panel applies to all of them.
# (The original relied on the loop variable `ax` left over from the loop above.)
axes[-1, 0].set_xlim(7, 35)
axes[-1, 0].set_ylim(-3, 11)


B=800
plt.rcParams['xtick.labelsize'] = 50
plt.rcParams['ytick.labelsize'] = 50
plt.rcParams['axes.titleweight']='bold'
plt.rcParams["font.weight"] = B
plt.rcParams["axes.labelweight"] = B

for n, cty in enumerate(allCities):
    YRTCity=YTR.loc[YTR['City']==cty]
    for d, rcp_code in enumerate(("R26", "R70", "R85")):
        # numeric_only=True: the frame also carries string columns, and averaging
        # those raises on pandas >= 2 (older versions silently dropped them).
        dataframe = (
            YRTCity.loc[YRTCity['RCP']==rcp_code]
            .groupby(['year','GCM','LAT'])
            .mean(numeric_only=True)
            .reset_index()
        )

        x = dataframe['Tmean'].to_numpy()
        y = dataframe['Yield'].to_numpy()
        xy = np.vstack([x, y])
        # Point density estimate; sort so the densest points are drawn last (on top).
        z = gaussian_kde(xy)(xy)
        idx = z.argsort()
        x, y, z = x[idx], y[idx], z[idx]
        # Colormap by name: plt.cm.get_cmap() was removed in recent matplotlib.
        im=axes[d,n].scatter(x, y,edgecolors = 'none',c = z,s = 18,cmap='jet',vmin=0,vmax=0.15,label =cty)
        # axes[d,n].legend() # Uncomment for double check the names
        plt.setp(axes[d,n].spines.values(), linewidth=2)
        axes[d,n].grid(True,linestyle="--", color='grey')
        if n==0 and d==1:
            axes[d,n].set_ylabel('Yield (t DM/ha)',fontsize=65)
        if d==2 and n==3:
            axes[d,n].set_xlabel('Mean air temperature ($^oC$)', labelpad=20,fontsize=60)

# All scatters share cmap/vmin/vmax, so any of them can drive the colorbar;
# use the scatter handle instead of guessing it from ax.get_children()[0].
cb=plt.colorbar(im, ax=axes.ravel().tolist(),pad=0.06,orientation = 'vertical')
cb.ax.set_title('Density',size=50,pad=20, weight='bold')
plt.show()


# =============================================================================
# =============================================================================
# Figure 3 --- Density plots --- low yielding plots 
# =============================================================================
# =============================================================================
import joypy
import scipy.stats
import matplotlib
import ptitprince as pt
from matplotlib.collections import PathCollection
from matplotlib.legend_handler import HandlerPathCollection, HandlerLine2D
class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter variant whose tick-label format is pinned to two decimals."""

    # printf-style pattern assigned to ``self.format``.
    _FIXED_FORMAT = "%0.2f"

    def _set_format(self):
        # Override of the hook that chooses the label format.
        self.format = self._FIXED_FORMAT
# plt.figure(dpi= 380)
col_list_ridge = ['year','Country','GCM','RCP','Yield']
# Explicit copy: the columns below are rewritten, and doing that on a slice of
# DSSAT triggers SettingWithCopyWarning and may not stick.
ridgeplot_data= DSSAT[col_list_ridge].copy()
ridgeplot_data['Country']=ridgeplot_data['Country'].str.strip().replace(
    {"It": "Italy", "US": "United States", "Ch": "China"}, regex=True)
ridgeplot_data["Yield"]=ridgeplot_data["Yield"]/1000

# 30-year period label, assigned in one vectorised step instead of chained
# df['Period'][mask] = ... writes. The first matching condition wins.
ridgeplot_data['Period']=np.select(
    [ridgeplot_data["year"] < 2010, ridgeplot_data["year"] < 2040, ridgeplot_data["year"] < 2070],
    ["1980-2009", "2010-2039", "2040-2069"],
    default="2070-2099",
)
ridgeplot_data['RCP']=ridgeplot_data['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)

countries = ["United States","Italy","China"]

class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter variant whose tick-label format is pinned to one decimal."""

    # printf-style pattern assigned to ``self.format``.
    _FIXED_FORMAT = "%0.1f"

    def _set_format(self):
        # Override of the hook that chooses the label format.
        self.format = self._FIXED_FORMAT
RCPs=["RCP 2.6","RCP 7.0","RCP 8.5"]
# (period label, colour) in drawing order; the KDE lines are read back by index
# in this same order.
PERIOD_STYLES = [
    ("1980-2009", "dodgerblue"),
    ("2010-2039", "#E58606"),
    ("2040-2069", "deeppink"),
    ("2070-2099", "#94346E"),
]
kwargs = dict(hist_kws={'alpha':.01}, kde_kws={'linewidth':2})
plt.rcParams['xtick.labelsize'] = 35
plt.rcParams['ytick.labelsize'] = 35
plt.rc('grid', linestyle="--", color='grey')

# One figure per country, one row per scenario.
for n, cnts in enumerate(countries):
    ridgeplot_data_i=ridgeplot_data.loc[ridgeplot_data['Country']==cnts]
    figure, axes = plt.subplots(3, 1,figsize=[11,14],sharex=True,sharey=True,dpi=100)
    figure.subplots_adjust( hspace=0.03)
    for r, rcp in enumerate(RCPs):
        ridgeplot_data_ii=ridgeplot_data_i.loc[ridgeplot_data_i['RCP']==rcp]

        # NOTE(review): sns.distplot is deprecated in seaborn; it is kept because
        # the shading below reads back the KDE lines it draws.
        for period, period_colour in PERIOD_STYLES:
            period_yield = ridgeplot_data_ii.loc[ridgeplot_data_ii["Period"]==period, 'Yield']
            axes[r]=sns.distplot(period_yield,bins=10, color=period_colour, label=period, **kwargs,ax=axes[r])

        # Shade the area under each of the four KDE curves.
        for line_no, (_, period_colour) in enumerate(PERIOD_STYLES):
            curve = axes[r].lines[line_no].get_xydata()
            axes[r].fill_between(curve[:,0], curve[:,1], color=period_colour, alpha=0.3)

        if r==1:
            axes[r].set_ylabel('Density',labelpad=30,fontsize=50)
        axes[r].set_xlabel('')

        if r==2:
            axes[r].set_xlabel('Simulated Yield (t DM/ ha)',labelpad=30,fontsize=45)
        # Separate formatter objects per axis: the original gave one instance to
        # both the x and the y axis, and a formatter tracks a single axis.
        axes[r].yaxis.set_major_formatter(ScalarFormatterForceFormat())
        axes[r].xaxis.set_major_formatter(ScalarFormatterForceFormat())
        yticks = ticker.MaxNLocator(3) ################## set number of y ticks manually
        axes[r].yaxis.set_major_locator(yticks)
        axes[r].grid()

        if r==0 and n==0:
            # axes[r].legend(fontsize=20,ncol=1,loc='upper left',prop = {'weight':'bold'})
            print('putting legend togeather')
            leg = axes[r].legend(loc='upper left',prop = {'size':25,'weight':'bold'})
            # plt.rc('legend',fontsize=25)
            # Legend.legendHandles was renamed to legend_handles in newer
            # matplotlib; support both.
            legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
            for lh in legend_handles:
                lh.set_alpha(0.9)
        # Axes are shared within the figure, so these limits apply to every row.
        axes[r].set_xlim(-1,13)
        axes[r].set_ylim(0,0.79)

    figure.tight_layout()

col_list = ['year','City','GCM','RCP','Yield']

lowfreq=DSSAT[col_list]
# Mean yield per scenario / year / GCM / region.
lowfreq_gpp=lowfreq.groupby(['RCP', 'year','GCM','City'], as_index=False)['Yield'].mean()
# Relabel RCP codes as SSP names. A single whole-value replace is used instead
# of chained df['RCP'][mask] = ... writes, which do nothing under pandas
# copy-on-write.
lowfreq_gpp['RCP']=lowfreq_gpp['RCP'].replace(
    {'R26': "SSP1-2.6", 'R70': "SSP3-7.0", 'R85': "SSP5-8.5"})

allcities=['California','Foggia','Emilia', 'Gansu', 'Xinjiang',   'Inner Mongolia']

# The original list repeated "SSP1-2.6" three times (copy-paste slip); it now
# holds the three scenario labels written into lowfreq_gpp['RCP'] above.
allRCP= ["SSP1-2.6","SSP3-7.0","SSP5-8.5"]
allGCM= ["1","2","3","4","5"]
class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter variant whose tick-label format is pinned to one decimal."""

    # printf-style pattern assigned to ``self.format``.
    _FIXED_FORMAT = "%0.1f"

    def _set_format(self):
        # Override of the hook that chooses the label format.
        self.format = self._FIXED_FORMAT

# Font settings shared by the two low-yield figures.
F=36
B=800
plt.rcParams.update({
    'xtick.labelsize': 34,
    'ytick.labelsize': 36,
    'axes.labelsize': 26,
    'axes.titleweight': 'bold',
    "font.weight": B,
    "axes.labelweight": B,
})

# fig2: 3x2 grid of bar charts, addressed through the flat array `axs`.
fig2, axes = plt.subplots(nrows=3, ncols=2, figsize=[20,30.5], sharex=True, sharey=True, dpi=300)
fig2.subplots_adjust(hspace=0.05, wspace=0.05)
axs = axes.ravel()

# fig4: 2x3 grid of line/bubble charts, addressed through the flat array `axxss`.
fig4, axes = plt.subplots(nrows=2, ncols=3, figsize=[30.5,21], sharex=True, sharey=True, dpi=300)
fig4.subplots_adjust(hspace=0.05, wspace=0.05)
axxss = axes.ravel()
# Display label for each region.
CITY_LABELS = {
    'California': 'United States (CA)',
    'Emilia': 'Italy (ER)',
    'Foggia': 'Italy (FG)',
    'Gansu': 'China (GA)',
    'Inner Mongolia': 'China (IM)',
    'Xinjiang': 'China (XJ)',
}
LAST_PANEL = 5  # index of the panel that carries the shared legends / labels

maxs=[]
for c, City in enumerate(allcities):
    lbl = CITY_LABELS[City]
    # .copy(): columns are added below, which must not happen on a slice.
    Citydata_1=lowfreq_gpp.loc[lowfreq_gpp['City']==City].copy()

    # Fewer than 1800 rows means some year/scenario/GCM combinations are missing
    # (presumably 120 years x 3 scenarios x 5 GCMs — confirm); rebuild the full
    # grid and fill the gaps.
    if len(Citydata_1) < 1800:
        indexed = Citydata_1.set_index(['year','RCP','GCM'])
        years_all = pd.Series(range(1980,2100),name='year') ## years 1980..2099
        (date_index, RCP_index,GMC_index) = indexed.index.levels ## levels present in the data
        new_index = pd.MultiIndex.from_product(
            [years_all, RCP_index, GMC_index], names=['year','RCP','GCM'])
        Citydata_1_fill = indexed.reindex(new_index)
        # Direct assignment instead of the chained ['City'].ffill(inplace=True),
        # which does nothing under pandas copy-on-write and never fills leading gaps.
        Citydata_1_fill['City'] = City
        Citydata_1_fill.reset_index(inplace=True)
        ## fill missing yields with the mean over the GCMs of the same scenario and year
        Citydata_1_fill['Yield'] =  Citydata_1_fill.groupby(['RCP','year'])['Yield'].transform(lambda x: x.fillna(x.mean()))
        Citydata_1=Citydata_1_fill.copy()

    # Low-yield threshold: the 10th percentile of this region's yields.
    # (The original also set TRSH=4 before the loop and computed an unused
    # baseline mean; both were dead and are dropped.)
    TRSH=Citydata_1['Yield'].quantile(0.1)

    Citydata_1['count'] = np.where(Citydata_1['Yield'] < TRSH, 1, 0)
    Citydata_1['Yeild_count'] = np.where(Citydata_1['count'] == 0, 0, Citydata_1['Yield'] )
    # Percent deviation from the threshold, for low-yield rows only (0 otherwise).
    Citydata_1['Yield_from_mean'] = np.where(Citydata_1['Yeild_count'] == 0, 0, 100*(Citydata_1['Yeild_count']-TRSH)/TRSH)
    Citydata_count =Citydata_1.groupby(['RCP','GCM']).agg({'count':'sum'}).reset_index()

    # ---- fig2: number of low-yield years per scenario and GCM ----
    pvdata=Citydata_count.pivot(index='RCP', columns='GCM', values='count')
    plt.rcParams['xtick.labelsize'] = 35
    plt.rcParams['ytick.labelsize'] = 35
    bar_colors=['#fcde9c','#faa476','#f0746e','#e34f6f','#dc3977']

    axs[c]=pvdata.plot(kind='bar',ax=axs[c],legend=False,rot=0,edgecolor='black',color=bar_colors, width=0.5)
    axs[c].set_ylabel('')
    axs[c].set_xlabel('')
    if c==LAST_PANEL:
        axs[c].legend(bbox_to_anchor=(0.80, -0.15),ncol=5,title="GCM",fontsize=40)
        axs[c].get_legend().get_title().set_fontsize('40')
        # Label typo fixed ("yeilding" -> "yielding").
        fig2.text(0.01,0.5, 'Total number of low yielding years',rotation='vertical', va='center', size=60)
    axs[c].annotate(lbl, xy=(0.02, 0.93), xycoords='axes fraction',fontsize=38)
    axs[c].grid()
    plt.setp(axs[c].spines.values(), linewidth=2)

    # ---- fig4: low-yield years per decade, bubble size = yield decrease ----
    Citydata_2 =Citydata_1.groupby([(Citydata_1['year']//10)*10,'RCP','GCM']).agg({'count':'sum', 'Yield_from_mean':'mean'}).reset_index()
    Citydata_3 =Citydata_2.groupby(['year','RCP']).agg({'count':'mean', 'Yield_from_mean':'mean'}).reset_index()
    Citydata_3["Yield_from_mean"]=Citydata_3["Yield_from_mean"]*-1
    # Single .loc write instead of chained indexing; also turns -0.0 into 0.
    Citydata_3.loc[Citydata_3['Yield_from_mean'] < .000001, 'Yield_from_mean'] = 0
    Citydata_3['range']="["+Citydata_3['year'].astype(str)+"-"+(Citydata_3['year']+10).astype(str)+"]"

    groups = Citydata_3.groupby('RCP')
    max_size=Citydata_3['Yield_from_mean'].max()
    maxs.append(max_size)
    scenario_colors=['#E68310','#008695','#CF1C90']

    for i, (name, group) in enumerate(groups):
        group.plot(kind='line', x='range', y='count', ax=axxss[c],label='',lw=4, alpha=0.5,color=scenario_colors[i],rot=90)
        group.plot(kind='scatter', x='range', y='count', s=group['Yield_from_mean']*20,alpha=0.6,label=name, ax=axxss[c], color=scenario_colors[i],rot=90)
        axxss[c].get_legend().remove()

    # Every second decade label, taken from the last scenario group plotted above.
    decade_ticks = group['range'][::2]
    axxss[c].set_xticks(decade_ticks)
    axxss[c].set_xticklabels(decade_ticks, rotation=90)

    if c==LAST_PANEL:

        print(max(maxs))
        # Reference bubble sizes for the "Yield decrease (%)" legend.
        pws = (pd.cut(pd.Series(np.arange(0,55,5)) , bins=10, retbins=True)[1]).round(0)
        for pw in pws:
            axxss[c].scatter([], [], s=(pw)*20, color="lightgrey", edgecolor='black',lw=2,label=str(pw))


        h, l = axxss[c].get_legend_handles_labels()
        legend1=axxss[c].legend(h[4:], l[4:], bbox_to_anchor=(1.05, 1.05),labelspacing=1,ncol=1, title="Yield decrease (%)", borderpad=.01,
                frameon=False, handletextpad=0.5, numpoints=1,fontsize=28)
        axxss[c].get_legend().get_title().set_fontsize('28')
        # Keep the first legend alive when the second one replaces it.
        axxss[c].add_artist(legend1)

        lgnd=axxss[c].legend(h[:3], l[:3],bbox_to_anchor=(1.5, 2),labelspacing=1,ncol=1,  borderpad=.01,
                        frameon=False, handletextpad=0.5, numpoints=1,fontsize=30)
        # Legend.legendHandles was renamed to legend_handles in newer matplotlib.
        scenario_handles = lgnd.legend_handles if hasattr(lgnd, 'legend_handles') else lgnd.legendHandles
        for handle in scenario_handles:
            handle.set_sizes([600.0])

    axxss[c].set_ylabel('')
    axxss[c].set_xlabel('')
    if c==LAST_PANEL:

        fig4.text(0.07,0.5, 'Low yielding years per decade',rotation='vertical', va='center', size=55)
    axxss[c].annotate(lbl, xy=(0.02, 0.93), xycoords='axes fraction',fontsize=38)
    axxss[c].grid()


# =============================================================================
# =============================================================================
# #Figure 4 --- Box plots 
# =============================================================================
# =============================================================================
from scipy import stats

col_list_PR = ['year','Country','CNCT','GCM','RCP','Yield','IR','ET','Tr']
DSSAT_subset= DSSAT[col_list_PR]

# numeric_only=True: 'CNCT' is a string column outside the group keys, and
# averaging it raises on pandas >= 2 (older versions silently dropped it).
grouped_data=(
    DSSAT_subset.groupby(['RCP', 'year','GCM','Country'])
    .mean(numeric_only=True)
    .reset_index()
)
# Yield per unit of irrigation (IR), evapotranspiration (ET) and transpiration (Tr).
grouped_data["WP_Irr"]=grouped_data['Yield']/grouped_data['IR']
grouped_data["WP_WUE"]=grouped_data['Yield']/grouped_data['ET']
grouped_data["WP_Tr"]=grouped_data['Yield']/grouped_data['Tr']

# 30-year period label, assigned in one vectorised step instead of chained
# df['Period'][mask] = ... writes. The first matching condition wins.
grouped_data['Period']=np.select(
    [grouped_data["year"] < 2010, grouped_data["year"] < 2040, grouped_data["year"] < 2070],
    ["1980-2009", "2010-2039", "2040-2069"],
    default="2070-2099",
)

grouped_data['Country']=grouped_data['Country'].str.strip().replace(
    {"It": "Italy", "US": "United States", "Ch": "China"}, regex=True)

# Outlier filtering: drop the top 10% of WP_Tr, then rows whose WP_Irr z-score
# is 1.5 or more in absolute value.
q=grouped_data["WP_Tr"].quantile(0.90)
grouped_data_F=grouped_data[grouped_data["WP_Tr"]< q]
grouped_data_F2=grouped_data_F[(np.abs(stats.zscore(grouped_data_F['WP_Irr'])) < 1.5)].copy()

# Short scenario label ("SSP1" unless the RCP code is R70 or R85). Set before
# the per-country frames are taken so they carry it; the original wrote it onto
# a slice afterwards, where it never reached Italy/USA/China/allCN.
grouped_data_F2['RCP_new']=grouped_data_F2['RCP'].map({'R70': "SSP3", 'R85': "SSP5"}).fillna("SSP1")

Italy=grouped_data_F2.loc[grouped_data_F2['Country']=='Italy']
USA=grouped_data_F2.loc[grouped_data_F2['Country']=='United States']
China=grouped_data_F2.loc[grouped_data_F2['Country']=='China']
allCN=pd.concat([Italy,USA,China])
productivity=["WP_Irr","WP_WUE","WP_Tr"]


# Font settings and labels for the Figure 4 box plots.
F=36
B=800
plt.rcParams['xtick.labelsize'] = 18
plt.rcParams['ytick.labelsize'] = 18
plt.rcParams['axes.labelsize'] = 26
plt.rcParams['axes.titleweight']='bold'
plt.rcParams["font.weight"] = B
plt.rcParams["axes.labelweight"] = B
cn1='Italy'
cn2='United States'
cn3='China'
lb1='Italy\n(FG, ER)'
lb2='United States\n(CA)'
lb3='China\n(IM, GA, XJ)'
# Short scenario label for the x axis: "SSP1" unless the RCP code is R70 or R85.
# Computed in one assignment; the chained form (allCN['RCP_new'][mask] = ...)
# does nothing under pandas copy-on-write and would leave every row as "SSP1".
allCN['RCP_new']=allCN['RCP'].map({'R70': "SSP3", 'R85': "SSP5"}).fillna("SSP1")

# Per metric: (y-axis title, y position of the title, y positions of the
# Italy / United States / China row labels). Positions are in data coordinates
# of the axes that plt.text() draws on (the current axes).
WP_ANNOTATIONS = {
    "WP_Irr": ('Irrigation Use efficiency (kg DM/ha/mm)', 65, (110, 68, 25)),
    "WP_WUE": ('Water Use efficiency (kg DM/ha/mm)', 42, (72, 45, 15)),
    "WP_Tr": ('Transpiration Use efficiency (kg DM/ha/mm)', 60, (100, 61, 21)),
}

sns.set(style="ticks", palette='Set2')
sns.set_context("paper", font_scale=2.1, rc={"lines.linewidth": 1.2})

for prdc in productivity:
    # sharex was given twice in the original (sharex='row' directly and
    # sharex=False inside facet_kws). Only the explicit argument is kept;
    # NOTE(review): newer seaborn is expected to reject the duplicate keyword —
    # confirm against the installed version. sharey=True is the default.
    g=sns.catplot(x="RCP_new", y=prdc,
                col="Period", row = "Country",
                data=allCN, kind="box", height=4,aspect=1.1,
                width=0.6,fliersize=2.5,showfliers=False, linewidth=1.3,sharex='row',row_order=[cn1,cn2,cn3],
                margin_titles=True,
                         notch=False,orient="v",palette=sns.color_palette(['#FFB14E', '#EA5F94','#0b9681']))
    # sns.despine(trim=True)
    # Blank the text objects on every facet; the row labels are drawn by hand
    # below. (The original kept a first, immediately overwritten label list.)
    for ax in g.axes.flat:
        plt.setp(ax.texts, text='',rotation=90, ha='center')

    axis_title, axis_title_y, row_label_ys = WP_ANNOTATIONS[prdc]
    plt.text(-11.3,axis_title_y, axis_title, va='center', rotation='vertical', size=26)
    for row_label, row_label_y in zip((lb1, lb2, lb3), row_label_ys):
        plt.text(2.8,row_label_y, row_label, va='center',ha='center', rotation='vertical', size=22)
    g.set_ylabels("")
    g.set_xticklabels( size = 17)

    g.set_xlabels("")
    g.fig.tight_layout()


# =============================================================================
# Figure 5 --- generate 3D figure
# =============================================================================
import scipy.linalg

class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter variant whose tick-label format is pinned to zero decimals."""

    # printf-style pattern assigned to ``self.format``.
    _FIXED_FORMAT = "%0.0f"

    def _set_format(self):
        # Override of the hook that chooses the label format.
        self.format = self._FIXED_FORMAT
YTR['RCP']=YTR['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)

# (value in YTR['RCP'], scenario name used in the figure title)
RCPs=[("RCP 2.6","SSP1-2.6"),("RCP 7.0","SSP3-7.0"),("RCP 8.5","SSP5-8.5")]
# RCPs = ["SSP1-2.6","SSP3-7.0","SSP5-8.5"]

# Font settings do not change between scenarios, so set them once.
plt.rcParams['xtick.labelsize'] = 30
plt.rcParams['ytick.labelsize'] = 30
plt.rcParams['axes.labelsize'] = 40
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"

for rcps in RCPs:
    GMC1_1_1=YTR.loc[YTR['RCP']==rcps[0]]
    # numeric_only=True: the frame also carries string columns, and averaging
    # those raises on pandas >= 2 (older versions silently dropped them).
    GMC1_1_1=GMC1_1_1.groupby(['year','LAT',"City"]).mean(numeric_only=True).reset_index()
    x=GMC1_1_1['LAT'].astype(float).values
    y=GMC1_1_1['Tmean'].values
    z=GMC1_1_1['Yield'].values

    data = np.c_[x,y,z]
    # regular 20x20 grid covering the latitude / temperature range of the data
    mn = np.min(data, axis=0)
    mx = np.max(data, axis=0)
    X,Y = np.meshgrid(np.linspace(mn[0], mx[0], 20), np.linspace(mn[1], mx[1], 20))
    XX = X.flatten()
    YY = Y.flatten()

    # Least-squares fit of a full quadratic surface:
    # z ~ c0 + c1*x + c2*y + c3*x*y + c4*x**2 + c5*y**2
    A = np.c_[np.ones(data.shape[0]), data[:,:2], np.prod(data[:,:2], axis=1), data[:,:2]**2]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2])

    # Evaluate the fitted surface on the grid.
    Z = np.dot(np.c_[np.ones(XX.shape), XX, YY, XX*YY, XX**2, YY**2], C).reshape(X.shape)

    fig = plt.figure(figsize=(15, 15))
    # fig.gca(projection='3d') is no longer accepted by recent matplotlib;
    # add_subplot creates the 3D axes explicitly.
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y, Z,cmap='cool', rstride=1, cstride=2, alpha=0.8)
    ax.scatter(data[:,0], data[:,1], data[:,2], c='purple',edgecolor="darkorange" ,s=5,alpha=0.5)

    ax.axis('tight')
    ax.set_xlabel('Latitude',labelpad=20)
    ax.set_ylabel('Mean air temperature ($^oC$)',labelpad=20, fontsize=35)
    ax.set_zlabel('Yield (t DM/ha)',labelpad=20, fontsize=38)
    yfmt = ScalarFormatterForceFormat()
    ax.yaxis.set_major_formatter(yfmt)
    ax.set_ylim(15,32)
    ax.set_title(rcps[1]+" scenario",fontsize=45,pad=25)
    plt.tight_layout()

